{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "run_control": {
     "marked": true
    }
   },
   "source": [
    "## Description：\n",
    "这个是参考了一个GitHub纯手撸了一遍， 通过这个能够加深对FFM内部的理解， 具体内容可以参考[https://github.com/Orisun/ffm](https://github.com/Orisun/ffm).<br><br>\n",
    "由于FFM对于数据集的存储格式有特殊的要求， 所以这里用的自己生成的训练集和测试集。 数据格式: label field_name, feature_name, value。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 导入包"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-26T02:13:57.876947Z",
     "start_time": "2020-09-26T02:13:57.873910Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import math"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义Singleton（单例模式)\n",
    "这个继承type， 继承type的类在python的概念中叫做元类，在python中所有的对象都是类，包括类本身。元类就是创建类的类。 <br><br>type和object类的关系: type是所有类的元类（包括object），object是所有类的父类（包括type）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-26T03:23:26.650187Z",
     "start_time": "2020-09-26T03:23:26.644210Z"
    }
   },
   "outputs": [],
   "source": [
    "class Singleton(type):\n",
    "    def __init__(cls, class_name, base_classes, attr_dict):\n",
    "        cls.__instance = None\n",
    "        super(Singleton, cls).__init__(class_name, base_classes, attr_dict)\n",
    "    \n",
    "    def __call__(cls, *args, **kwargs):\n",
    "        if cls.__instance is None:\n",
    "            cls.__instance = super(Singleton, cls).__call__(*args, **kwargs)\n",
    "        else:\n",
    "            return cls.__instance"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义逻辑回归类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-26T03:23:27.673457Z",
     "start_time": "2020-09-26T03:23:27.665472Z"
    }
   },
   "outputs": [],
   "source": [
    "class Logistic(object):\n",
    "    __metaclass__ = Singleton   # 单例\n",
    "    \n",
    "    def __init__(self):\n",
    "        self.exp_max = 10.0\n",
    "        self.exp_scale = 0.001\n",
    "        self.exp_intv = int(self.exp_max / self.exp_scale)   # 10000\n",
    "        self.exp_table = [0.0] * self.exp_intv      # 10000个0\n",
    "        for i in range(self.exp_intv):   # [10000]\n",
    "            x = self.exp_scale * i # [0.001 * i for i in range(0, 10000)]   x 是0-10\n",
    "            exp = math.exp(x)  # 1-e^10\n",
    "            self.exp_table[i] = exp / (1.0 + exp)       # 归一化缩放  [0, 1]     \n",
    "    \n",
    "    # 这里是手动求逻辑回归函数的值\n",
    "    def decide_by_table(self, x):\n",
    "        \"\"\"查表获得逻辑回归的函数值\"\"\"\n",
    "        if x == 0:\n",
    "            return 0.5\n",
    "        i = int(np.nan_to_num(abs(x) / self.exp_scale))\n",
    "        y = self.exp_table[min(i, self.exp_intv - 1)]\n",
    "        if x > 0:\n",
    "            return y\n",
    "        else:\n",
    "            return 1.0 - y\n",
    "    \n",
    "    def decide_by_tanh(self, x):\n",
    "        '''直接使用1.0 / (1.0 + np.exp(-x))容易发警告“RuntimeWarning: overflowencountered in exp”，\n",
    "           转换成如下等价形式后算法会更稳定\n",
    "        '''\n",
    "        return 0.5 * (1 + np.tanh(0.5 * x))\n",
    "\n",
    "    def decide(self, x):\n",
    "        \"\"\"原始的sigmoid函数\"\"\"\n",
    "        return 1.0 / (1.0 + np.exp(-x))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义FFM模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-26T03:24:22.043920Z",
     "start_time": "2020-09-26T03:24:22.024916Z"
    }
   },
   "outputs": [],
   "source": [
    "class FFM_Node(object):\n",
    "    '''\n",
    "    通常x是高维稀疏向量，所以用链表来表示一个x，链表上的每个节点是个3元组(j,f,v)\n",
    "    '''\n",
    "    __slots__ = ['j', 'f', 'v']    # 按照元组不是字典的方式存储类的成员属性\n",
    "    \n",
    "    def __init__(self, j, f, v):\n",
    "        \"\"\"\n",
    "            j: Feature index (0-n-1)\n",
    "            f: field index(0-m-1)\n",
    "            v: value\n",
    "        \"\"\"\n",
    "        self.j = j\n",
    "        self.f = f\n",
    "        self.v = v\n",
    "    \n",
    "class FFM(object):\n",
    "    def __init__(self, m, n, k, eta, lambd):\n",
    "        \"\"\"\n",
    "            m: Number of fields\n",
    "            n: Number of features\n",
    "            k: Number of latent factors\n",
    "            eta: learning rate\n",
    "            lambd: regularization coefficient\n",
    "        \"\"\"\n",
    "        self.m = m\n",
    "        self.n = n\n",
    "        self.k = k\n",
    "        \n",
    "        #超参数\n",
    "        self.eta = eta\n",
    "        self.lambd = lambd\n",
    "        \n",
    "        # 初始化三维权重矩阵w~U(0, 1/sqrt(k))\n",
    "        self.w = np.random.rand(n, m, k) / math.sqrt(k)\n",
    "        \n",
    "        # 初始化累积梯度平方和， AdaGrad时要用到\n",
    "        self.G = np.ones(shape=(n, m, k), dtype=np.float64)\n",
    "        self.log = Logistic()\n",
    "    \n",
    "    # 这个是计算第三项\n",
    "    def phi(self, node_list):\n",
    "        \"\"\"\n",
    "        特征组合式的线性加权求和\n",
    "        param node_list: 用链表存储x中的非0值\n",
    "        \"\"\"\n",
    "        z = 0.0\n",
    "        for a in range(len(node_list)):\n",
    "            node1 = node_list[a]\n",
    "            j1 = node1.j\n",
    "            f1 = node1.f\n",
    "            v1 = node1.v\n",
    "            for b in range(a+1, len(node_list)):\n",
    "                node2 = node_list[b]\n",
    "                j2 = node2.j\n",
    "                f2 = node2.f\n",
    "                v2 = node2.v\n",
    "                w1 = self.w[j1, f2]\n",
    "                w2 = self.w[j2, f1]\n",
    "                z += np.dot(w1, w2) * v1 * v2\n",
    "        \n",
    "        return z\n",
    "    \n",
    "    \n",
    "    def predict(self, node_list):\n",
    "        \"\"\"\n",
    "        输入x， 预测y的值\n",
    "        \"\"\"\n",
    "        z = self.phi(node_list)\n",
    "        y = self.log.decide_by_tanh(z)\n",
    "        return y\n",
    "\n",
    "    # 随机梯度下降\n",
    "    def sgd(self, node_list, y):\n",
    "        \"\"\"\n",
    "        根据一个样本更新模型参数：\n",
    "        node_list: 链表存储x中的非0值\n",
    "        y: 正样本1， 负样本-1\n",
    "        \"\"\"\n",
    "        kappa = -y / (1+math.exp(y*self.phi(node_list)))    # 论文里面的那个导数\n",
    "        for a in range(len(node_list)):\n",
    "            node1 = node_list[a]\n",
    "            j1 = node1.j\n",
    "            f1 = node1.f\n",
    "            v1 = node1.v\n",
    "            for b in range(a+1, len(node_list)):\n",
    "                node2 = node_list[b]\n",
    "                j2 = node2.j\n",
    "                f2 = node2.f\n",
    "                v2 = node2.v\n",
    "                c = kappa * v1 * v2      # 这是求导数\n",
    "                \n",
    "                # self.w[j1,f2]和self.w[j2,f1]是向量，导致g_j1_f2和g_j2_f1也是向量\n",
    "                g_j1_f2 = self.lambd * self.w[j1, f2] + c * self.w[j2, f1]\n",
    "                g_j2_f1 = self.lambd * self.w[j2, f1] + c * self.w[j1, f2]\n",
    "                \n",
    "                # 计算各个维度上的梯度累积平方和\n",
    "                self.G[j1, f2] += g_j1_f2 ** 2\n",
    "                self.G[j2, f1] += g_j2_f1 ** 2\n",
    "                \n",
    "                # Adagrad 算法\n",
    "                self.w[j1, f2] -= self.eta / np.sqrt(self.G[j1, f2]) * g_j1_f2  # sqrt(G)作为分母，所以G必须是大于0的正数\n",
    "                self.w[j2, f1] -= self.eta / np.sqrt(\n",
    "                    self.G[j2, f1]) * g_j2_f1  # math.sqrt()只能接收一个数字作为参数，而numpy.sqrt()可以接收一个array作为参数，表示对array中的每个元素分别开方\n",
    "    \n",
    "    # 训练\n",
    "    def train(self, sample_generator, max_echo, max_r2):\n",
    "        \"\"\"\n",
    "        根据一堆样本训练模型\n",
    "        sample_generator: 样本生成器，每次yield (node_list, y)，node_list中存储的是x的非0值。通常x要事先做好归一化，即模长为1，这样精度会略微高一点\n",
    "        max_echo: 最大迭代次数\n",
    "        max_r2: 拟合系数r2达到阈值时即可终止学习\n",
    "        \"\"\"\n",
    "        for itr in range(max_echo):\n",
    "            print(\"echo: \", itr)\n",
    "            y_sum = 0.0\n",
    "            y_sqare_sum = 0.0\n",
    "            err_square_sum = 0.0    # 误差平方和\n",
    "            population = 0   # 样本总数\n",
    "            for node_list, y in sample_generator:\n",
    "                y = 0.0  if y == -1 else y    # 真实的y取值为{-1,1}，而预测的y位于(0,1)，计算拟合效果时需要进行统一\n",
    "                self.sgd(node_list, y)\n",
    "                y_hat = self.predict(node_list)\n",
    "                y_sum += y\n",
    "                y_sqare_sum += y ** 2\n",
    "                err_square_sum += (y-y_hat) ** 2\n",
    "                population += 1\n",
    "            \n",
    "            var_y = y_sqare_sum - y_sum * y_sum / population  # y的方差\n",
    "            r2 = 1 - err_square_sum / var_y\n",
    "            print(\"r2: \", r2)\n",
    "            if r2 > max_r2:\n",
    "                print(\"r2 have reach\", r2)\n",
    "                break\n",
    "        \n",
    "    # 模型保存\n",
    "    def save_model(self, outfile):\n",
    "        '''\n",
    "        序列化模型\n",
    "        :param outfile:\n",
    "        :return:\n",
    "        '''\n",
    "        np.save(outfile, self.w)\n",
    "\n",
    "    def load_model(self, infile):\n",
    "        '''\n",
    "        加载模型\n",
    "        :param infile:\n",
    "        :return:\n",
    "        '''\n",
    "        self.w = np.load(infile)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-26T03:23:33.249012Z",
     "start_time": "2020-09-26T03:23:33.240993Z"
    }
   },
   "outputs": [],
   "source": [
    "import re\n",
    "class Sample(object):\n",
    "    def __init__(self, infile):\n",
    "        self.infile = infile\n",
    "        self.regex = re.compile(\"\\\\s+\")\n",
    "\n",
    "    def __iter__(self):\n",
    "        with open(self.infile, 'r') as f_in:\n",
    "            for line in f_in:\n",
    "                arr = self.regex.split(line.strip())\n",
    "                if len(arr) >= 2:\n",
    "                    y = float(arr[0])\n",
    "                    assert math.fabs(y) == 1\n",
    "                    node_list = []\n",
    "                    square_sum = 0.0\n",
    "                    for i in range(1, len(arr)):\n",
    "                        brr = arr[i].split(\",\")\n",
    "                        if len(brr) == 3:\n",
    "                            j = int(brr[0])\n",
    "                            f = int(brr[1])\n",
    "                            v = float(brr[2])\n",
    "                            square_sum += v * v\n",
    "                            node_list.append(FFM_Node(j, f, v))\n",
    "                    if square_sum > 0:\n",
    "                        norm = math.sqrt(square_sum)\n",
    "                        # 把模长缩放到1\n",
    "                        normed_node_list = [FFM_Node(ele.j, ele.f, ele.v / norm) for ele in node_list]\n",
    "                        yield (normed_node_list, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-26T03:24:26.655003Z",
     "start_time": "2020-09-26T03:24:26.625099Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "echo:  0\n",
      "r2:  -0.16471447551138296\n",
      "echo:  1\n",
      "r2:  -0.16470221265226326\n",
      "echo:  2\n",
      "r2:  -0.16468995371475814\n",
      "echo:  3\n",
      "r2:  -0.1646776986975249\n",
      "echo:  4\n",
      "r2:  -0.1646654475992202\n",
      "echo:  5\n",
      "r2:  -0.1646532004185024\n",
      "echo:  6\n",
      "r2:  -0.16464095715402927\n",
      "echo:  7\n",
      "r2:  -0.1646287178044601\n",
      "echo:  8\n",
      "r2:  -0.16461648236845394\n",
      "echo:  9\n",
      "r2:  -0.16460425084467079\n",
      "echo:  10\n",
      "r2:  -0.1645920232317708\n",
      "echo:  11\n",
      "r2:  -0.16457979952841506\n",
      "echo:  12\n",
      "r2:  -0.16456757973326486\n",
      "echo:  13\n",
      "r2:  -0.16455536384498193\n",
      "echo:  14\n",
      "r2:  -0.1645431518622289\n",
      "echo:  15\n",
      "r2:  -0.16453094378366884\n",
      "echo:  16\n",
      "r2:  -0.16451873960796504\n",
      "echo:  17\n",
      "r2:  -0.16450653933378123\n",
      "echo:  18\n",
      "r2:  -0.16449434295978183\n",
      "echo:  19\n",
      "r2:  -0.1644821504846321\n",
      "echo:  20\n",
      "r2:  -0.16446996190699736\n",
      "echo:  21\n",
      "r2:  -0.16445777722554378\n",
      "echo:  22\n",
      "r2:  -0.16444559643893775\n",
      "echo:  23\n",
      "r2:  -0.1644334195458459\n",
      "echo:  24\n",
      "r2:  -0.16442124654493617\n",
      "echo:  25\n",
      "r2:  -0.16440907743487654\n",
      "echo:  26\n",
      "r2:  -0.1643969122143356\n",
      "echo:  27\n",
      "r2:  -0.164384750881982\n",
      "echo:  28\n",
      "r2:  -0.16437259343648547\n",
      "echo:  29\n",
      "r2:  -0.16436043987651683\n"
     ]
    }
   ],
   "source": [
    "# 设置参数   5个特征， 2个域， 2维的k\n",
    "n = 5\n",
    "m = 2\n",
    "k = 2\n",
    "\n",
    "# 路径\n",
    "train_file = \"dataset/train.txt\"\n",
    "valid_file = \"dataset/test.txt\"\n",
    "model_file = \"ffm.npy\"\n",
    "\n",
    "# 超参数\n",
    "eta = 0.01\n",
    "lambd = 1e-2\n",
    "max_echo = 30\n",
    "max_r2 = 0.9\n",
    "\n",
    "# 训练模型，并保存模型参数\n",
    "sample_generator = Sample(train_file)\n",
    "ffm = FFM(m, n, k, eta, lambd)\n",
    "ffm.train(sample_generator, max_echo, max_r2)\n",
    "ffm.save_model(model_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-26T03:24:30.271164Z",
     "start_time": "2020-09-26T03:24:30.264182Z"
    }
   },
   "outputs": [],
   "source": [
    "# 加载模型，并计算在验证集上的拟合效果\n",
    "ffm.load_model(model_file)\n",
    "valid_generator = Sample(valid_file)\n",
    "y_sum = 0.0\n",
    "y_square_sum = 0.0\n",
    "err_square_sum = 0.0  # 误差平方和\n",
    "population = 0  # 样本总数\n",
    "for node_list, y in valid_generator:\n",
    "    y = 0.0 if y == -1 else y  # 真实的y取值为{-1,1}，而预测的y位于(0,1)，计算拟合效果时需要进行统一\n",
    "    y_hat = ffm.predict(node_list)\n",
    "    y_sum += y\n",
    "    y_square_sum += y ** 2\n",
    "    err_square_sum += (y - y_hat) ** 2\n",
    "    population += 1\n",
    "var_y = y_square_sum - y_sum * y_sum / population  # y的方差\n",
    "r2 = 1 - err_square_sum / var_y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-26T03:24:50.808517Z",
     "start_time": "2020-09-26T03:24:50.800029Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-0.027269624863842212"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "r2"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
